from datetime import datetime, timedelta, timezone
from typing import Annotated

import jwt
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from common import status
from configs import DEMO, ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_MINUTES, ACCESS_TOKEN_CACHE_MINUTES, SECRET_KEY, ALGORITHM, OAUTH_ENABLE, OAUTH2_SCHEME
from core.exception import CustomException
from core.response import SuccessResponse, ErrorResponse
from core.validator import valid_telephone
from database.database import db_getter
from modules.admin.auth import dal, model, param, schema


class LoginForm(BaseModel):
    """
    Request body of the login endpoint.

    When ``method`` is "1", ``username`` must be a telephone number. It is passed
    through ``valid_telephone`` after field validation.
    """

    username: str
    password: str
    method: str = '0'    # authentication method: "0" = password login, "1" = telephone-based login
    platform: str = '0'  # login platform: "0" = PC admin system

    # The former class-level `if method == "1":` was evaluated once, when the class
    # was created, against the default '0'. So the telephone validator was never
    # attached. A model validator runs for every instance and can see `method`.
    # Reusing validators: https://docs.pydantic.dev/dev-v2/usage/validators/#reuse-validators
    @model_validator(mode="after")
    def normalize_telephone(self) -> "LoginForm":
        """Validate/normalize ``username`` with ``valid_telephone`` when method == "1"."""
        # NOTE(review): assumes valid_telephone(value) returns the (possibly normalized)
        # value and raises ValueError when it is invalid. This is the contract it
        # satisfied as a pydantic field validator. Confirm in core.validator.
        if self.method == "1":
            self.username = valid_telephone(self.username)
        return self


class LoginResult(BaseModel):
    """Outcome of a login attempt: success flag, authenticated user, and message."""

    # Pydantic v2 configuration; the class-based `Config` is deprecated in v2.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: bool | None = False
    user: schema.UserPasswordOut | None = None
    msg: str | None = None


class LoginValidation:
    """
    Decorator that wraps a login strategy with the checks shared by every login method.

    It does the following:
    - validates ``platform`` / ``method``;
    - looks the user up by username;
    - delegates the credential check to the wrapped function;
    - on success, updates the user's login info.

    The decorated class attribute becomes an instance of this class (it defines no
    ``__get__``). So ``manage.password_login(data, db, request)`` calls ``__call__``
    directly.
    """

    def __init__(self, func):
        # The wrapped login strategy, e.g. LoginManage.password_login.
        self.func = func

    async def __call__(self, data: LoginForm, db: AsyncSession, request: Request) -> LoginResult:
        """
        Run the shared login checks around the wrapped strategy.

        :param data: submitted login form
        :param db: database session
        :param request: current request; its client host is recorded on success
        :return: LoginResult. On success ``status`` is True and ``user`` is filled in;
            otherwise ``msg`` explains the failure.
        """
        # Keep the result local. A single LoginValidation instance is shared by all
        # requests, so storing it on `self` (as before) let concurrent logins
        # overwrite each other's result across the awaits below.
        outcome = LoginResult()
        if data.platform not in ("0", "1") or data.method not in ("0", "1"):
            outcome.msg = "无效参数"
            return outcome
        user = await dal.UserDal(db).get_data(username=data.username, v_return_none=True)
        if not user:
            outcome.msg = "用户名或密码错误"
            return outcome

        # self.func is a plain function stored on the instance (not bound), so the
        # first positional argument is supplied explicitly.
        check = await self.func(self, data=data, user=user, request=request)

        # The failed-attempt counter (freeze the account after too many failures) is
        # disabled. Its former, commented-out implementation used a Redis-backed
        # Count, so `count` is always None here and the counter branches never run.
        # NOTE(review): DEFAULT_AUTH_ERROR_MAX_NUMBER is not imported in this module;
        # add the import before re-enabling the counter.
        count = None

        if not check.status:
            outcome.msg = check.msg
            if not DEMO and count:
                number = await count.add(ex=86400)
                if number >= DEFAULT_AUTH_ERROR_MAX_NUMBER:
                    await count.reset()
                    # Maximum number of failures reached: freeze the user (is_active=False).
                    user.is_active = False
                    await db.flush()
        elif not user.is_active:
            # Checked after the credentials, presumably so a frozen account is only
            # reported to callers who know its password.
            outcome.msg = "此用户已被冻结！"
        else:
            if not DEMO and count:
                await count.delete()
            outcome.msg = "OK"
            outcome.status = True
            outcome.user = schema.UserPasswordOut.model_validate(user)
            await dal.UserDal(db).update_login_info(user, request.client.host)
        return outcome


class LoginManage:
    """Login strategies and JWT helpers used by the /auth/login endpoint."""

    @LoginValidation
    async def password_login(self, data: LoginForm, user: model.AdminUser, **kwargs) -> LoginResult:
        """
        Verify the submitted password against the user's stored password.

        Wrapped by LoginValidation, which looks ``user`` up beforehand and passes its
        own instance as ``self`` (not a LoginManage instance).

        :param data: submitted login form
        :param user: user found by username
        :return: LoginResult with status True on a match, otherwise a failure message
        """
        if model.AdminUser.verify_password(data.password, user.password):
            return LoginResult(status=True, msg="验证成功")
        return LoginResult(status=False, msg="用户名或密码错误")

    @staticmethod
    def create_token(payload: dict, expires: timedelta | None = None):
        """
        Encode ``payload`` as a signed JWT with an ``exp`` claim.

        pyjwt: https://github.com/jpadilla/pyjwt/blob/master/docs/usage.rst
        JWT overview: https://geek-docs.com/python/python-tutorial/j_python-jwt.html

        :param payload: claims to encode. It is not modified; a copy receives "exp".
        :param expires: token lifetime. None means ACCESS_TOKEN_EXPIRE_MINUTES. An
            explicit ``timedelta(0)`` is honoured instead of silently falling back
            to the default.
        :return: the encoded token as returned by ``jwt.encode``
        """
        if expires is None:
            expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        # pyjwt turns a datetime "exp" into a POSIX timestamp. An aware UTC datetime
        # makes that conversion unambiguous; datetime.utcnow() is deprecated since
        # Python 3.12. The decoded "exp" is that same POSIX timestamp, so it compares
        # directly with datetime.now().timestamp().
        expire_at = datetime.now(timezone.utc) + expires
        claims = {**payload, "exp": expire_at}
        return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)


class Auth(BaseModel):
    """
    Authentication context returned by the auth dependencies below.

    ``user`` is None when authentication is disabled or the caller is anonymous.
    """

    # Accept non-pydantic types (ORM user, AsyncSession). Pydantic v2 form of the
    # deprecated class-based `Config`.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    user: model.AdminUser | None = None
    db: AsyncSession
    # Data-scope level 0-4 computed by AuthValidation.get_user_data_range; None when not computed.
    data_range: int | None = None
    # Department ids the user may access (["*"] for all); empty when not computed.
    dept_ids: list | None = []


class AuthValidation:
    """
    Base for per-request auth dependencies.

    On every call to a protected endpoint it validates the submitted token and
    loads the user identified by it.
    """

    # status_code = 401: force a new login. Causes include a frozen account, an
    #   expired account, a wrong phone number, an invalid refresh token, etc.
    # code = 401 only (HTTP status not 401): the token expired and must be refreshed.
    # code = any other error value: only report the error, without forcing re-login.
    error_code = status.HTTP_401_UNAUTHORIZED
    warning_code = status.HTTP_ERROR

    # status_code = 403: force a new login. Cause: entering the system without
    #   system permission.

    @classmethod
    def validate_token(cls, request: Request, token: str | None) -> tuple[str, str]:
        """
        Decode and validate an access token.

        Sets ``request.scope["if-refresh"]`` to 1 when the token expires within
        ACCESS_TOKEN_CACHE_MINUTES, otherwise 0.

        :param request: current request
        :param token: raw bearer token, may be None/empty
        :return: (username, password) taken from the "sub" and "password" claims
        :raises CustomException: 403 when the token is missing, invalid, a refresh
            token, or lacks required claims; ``error_code`` (401) when it expired.
        """
        if not token:
            raise CustomException(
                msg="请您先登录！",
                code=status.HTTP_403_FORBIDDEN,
                status_code=status.HTTP_403_FORBIDDEN
            )
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.exceptions.ExpiredSignatureError as exc:
            raise CustomException(msg="认证已失效，请您重新登录", code=cls.error_code, status_code=cls.error_code) from exc
        except jwt.exceptions.InvalidTokenError as exc:
            # Base class of every other PyJWT validation error (bad signature,
            # malformed token, disallowed algorithm, ...). Previously only
            # signature/decode errors were caught; the rest surfaced as HTTP 500.
            raise CustomException(
                msg="无效认证，请您重新登录",
                code=status.HTTP_403_FORBIDDEN,
                status_code=status.HTTP_403_FORBIDDEN
            ) from exc

        username: str | None = payload.get("sub")
        exp: int | None = payload.get("exp")
        is_refresh: bool | None = payload.get("is_refresh")
        password: str | None = payload.get("password")
        # Reject refresh tokens and tokens missing required claims. A missing "exp"
        # used to raise TypeError in the comparison below.
        if not username or is_refresh or not password or exp is None:
            raise CustomException(
                msg="未认证，请您重新登录",
                code=status.HTTP_403_FORBIDDEN,
                status_code=status.HTTP_403_FORBIDDEN
            )
        # Flag the token for refresh once now + buffer reaches its expiry.
        # datetime.now().timestamp() is a POSIX timestamp, the same scale as "exp".
        buffer_time = (datetime.now() + timedelta(minutes=ACCESS_TOKEN_CACHE_MINUTES)).timestamp()
        request.scope["if-refresh"] = 1 if buffer_time >= exp else 0
        return username, password

    @classmethod
    async def validate_user(cls, request: Request, user: model.AdminUser, db: AsyncSession, is_all: bool = True) -> Auth:
        """
        Check the user and build the Auth context.

        The user must exist and be active. Identity info and the raw request body
        are recorded on ``request.scope``.

        :param request: current request
        :param user: user loaded from the token claims, or None
        :param db: database session
        :param is_all: True when any authenticated user may access (no data-scope lookup)
        :return: Auth context, with data_range/dept_ids filled when ``is_all`` is False
        :raises CustomException: ``error_code`` when the user is missing or frozen
        """
        if user is None:
            raise CustomException(msg="未认证，请您重新登陆", code=cls.error_code, status_code=cls.error_code)
        elif not user.is_active:
            raise CustomException(msg="用户已被冻结！", code=cls.error_code, status_code=cls.error_code)
        request.scope["id"] = user.id
        request.scope["username"] = user.username
        request.scope["fullname"] = user.fullname
        request.scope["telephone"] = user.telephone
        try:
            request.scope["body"] = await request.body()
        except RuntimeError:
            # The body could not be read (e.g. already consumed as a stream).
            request.scope["body"] = "获取失败"
        if is_all:
            return Auth(user=user, db=db)
        data_range, dept_ids = await cls.get_user_data_range(user, db)
        return Auth(user=user, db=db, data_range=data_range, dept_ids=dept_ids)

    @classmethod
    def get_user_permissions(cls, user: model.AdminUser) -> set:
        """
        Collect all permission strings of a staff user.

        :param user: user instance with roles and their menus loaded
        :return: {'*.*.*'} for admins; otherwise the ``perms`` of every enabled menu
            reachable through the user's roles
        """
        if user.is_admin():
            return {'*.*.*'}
        return {
            menu.perms
            for role_obj in user.roles
            for menu in role_obj.menus
            if menu.perms and not menu.disabled
        }

    @classmethod
    async def get_user_data_range(cls, user: model.AdminUser, db: AsyncSession) -> tuple:
        """
        Determine the user's data scope.

        0 own data only (filtered by create_user_id)
        1 own department (department id join)
        2 own department and its sub-departments (department id join)
        3 custom: departments attached to the user's roles (department id join)
        4 all data

        The widest scope among the user's roles wins.

        :param user: user instance with roles/depts loaded
        :param db: database session
        :return: (data_range, dept_ids)
        """
        if user.is_admin():
            return 4, ["*"]
        # A user without roles gets the most restrictive scope (0). Previously max()
        # raised ValueError on the empty sequence.
        data_range = max((role.data_range for role in user.roles), default=0)
        dept_ids = set()
        if data_range == 1:
            dept_ids.update(dept.id for dept in user.depts)
        elif data_range == 2:
            # Recursively collect the department tree.
            dept_ids = await dal.UserDal(db).recursion_get_dept_ids(user)
        elif data_range == 3:
            dept_ids.update(dept.id for role_obj in user.roles for dept in role_obj.depts)
        elif data_range == 4:
            dept_ids.add("*")
        return data_range, list(dept_ids)


class OpenAuth(AuthValidation):

    """
    Optional authentication: the endpoint is reachable without credentials.

    A valid token for an active user yields an Auth carrying that user. Any
    authentication failure (CustomException) falls back to an anonymous Auth
    holding only the DB session.
    NOTE(review): whether a request without a token reaches __call__ at all depends
    on how OAUTH2_SCHEME is configured (e.g. auto_error). Confirm in configs.
    """

    async def __call__(
        self,
        request: Request,
        token: Annotated[str, Depends(OAUTH2_SCHEME)],
        db: AsyncSession = Depends(db_getter)
    ):
        """Resolve the optional caller identity for each request to a dependent endpoint."""
        if OAUTH_ENABLE:
            try:
                name, pwd_claim = self.validate_token(request, token)
                current = await dal.UserDal(db).get_data(username=name, password=pwd_claim, v_return_none=True)
                return await self.validate_user(request, current, db, is_all=True)
            except CustomException:
                # Invalid/expired token or unusable account: treat the caller as anonymous.
                pass
        return Auth(db=db)


class AllUserAuth(AuthValidation):

    """
    Authentication for any logged-in user; yields the user's basic record.

    Token or account problems propagate as CustomException.
    """

    async def __call__(
        self,
        request: Request,
        token: str = Depends(OAUTH2_SCHEME),
        db: AsyncSession = Depends(db_getter)
    ):
        """Authenticate the caller for each request to a dependent endpoint."""
        if OAUTH_ENABLE:
            name, pwd_claim = self.validate_token(request, token)
            current = await dal.UserDal(db).get_data(username=name, password=pwd_claim, v_return_none=True)
            return await self.validate_user(request, current, db, is_all=True)
        # Authentication switched off globally: anonymous context with the DB session only.
        return Auth(db=db)


class FullAdminAuth(AuthValidation):

    """
    Staff-only authentication that loads the complete user record.

    The record includes roles with their menus and departments, plus the user's
    departments. When permissions were given at construction, the user must hold
    at least one of them.
    """

    def __init__(self, permissions: list[str] | None = None):
        # Required permissions as a set, or None when any staff user is allowed.
        self.permissions = set(permissions) if permissions else None

    async def __call__(
        self,
        request: Request,
        token: str = Depends(OAUTH2_SCHEME),
        db: AsyncSession = Depends(db_getter)
    ) -> Auth:
        """Authenticate and authorize the caller for each request to a dependent endpoint."""
        if not OAUTH_ENABLE:
            return Auth(db=db)
        name, pwd_claim = self.validate_token(request, token)
        # Eager-load everything permission and data-scope checks read.
        eager_loads = [
            joinedload(model.AdminUser.roles).subqueryload(model.AdminRole.menus),
            joinedload(model.AdminUser.roles).subqueryload(model.AdminRole.depts),
            joinedload(model.AdminUser.depts),
        ]
        current = await dal.UserDal(db).get_data(
            username=name,
            password=pwd_claim,
            v_return_none=True,
            v_options=eager_loads,
        )
        auth = await self.validate_user(request, current, db, is_all=False)
        granted = self.get_user_permissions(current)
        required = self.permissions
        # Admins ('*.*.*') bypass the check; others need at least one required permission.
        if required and granted != {'*.*.*'} and required.isdisjoint(granted):
            raise CustomException(msg="无权限操作", code=status.HTTP_403_FORBIDDEN)
        return auth


# Router for the login / current-user endpoints, mounted under /auth.
router = APIRouter(prefix="/auth", tags=["用户登录"])

@router.post("/login", summary="账号密码登录", description="PC端登录通道，限制最多输错次数，达到最大值后将is_active=False")
async def login_for_access_token(
    request: Request,
    data: LoginForm,
    manage: LoginManage = Depends(),
    db: AsyncSession = Depends(db_getter)
):
    """
    Log a user in and issue an access/refresh token pair.

    Failures are returned as an ErrorResponse carrying the failure message. These
    include unsupported method, bad credentials and a frozen account.
    """
    try:
        if data.method == "0":
            result = await manage.password_login(data, db, request)
        elif data.method == "1" and hasattr(manage, "sms_login"):
            # LoginManage does not currently define sms_login. Without this guard,
            # method "1" raised AttributeError (HTTP 500) instead of the normal
            # "invalid parameter" error response.
            result = await manage.sms_login(data, db, request)
        else:
            raise ValueError("无效参数")

        if not result.status:
            raise ValueError(result.msg)

        # SECURITY NOTE(review): the "password" claim embeds the user's stored
        # password value (presumably a hash) in the JWT. A JWT payload is only
        # base64-encoded and is readable by the client.
        # AuthValidation.validate_token requires this claim and the user lookup
        # matches on it. Any replacement (e.g. a password-version counter) must
        # change both sides together.
        username = result.user.username
        password = result.user.password
        access_token = LoginManage.create_token(
            {"sub": username, "is_refresh": False, "password": password}
        )
        refresh_token = LoginManage.create_token(
            payload={"sub": username, "is_refresh": True, "password": password},
            expires=timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
        )
        resp = {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "is_reset_password": result.user.is_reset_password
        }
        # await AdminLoginRecord.create_login_record(db, data, True, request, resp)
        return SuccessResponse(resp)
    except ValueError as e:
        # await AdminLoginRecord.create_login_record(db, data, False, request, {"message": str(e)})
        return ErrorResponse(msg=str(e))


@router.get("/current/user/info", summary="获取当前用户信息")
async def get_current_user_info(auth: Auth = Depends(FullAdminAuth())):
    """
    Return the current user's profile together with its permission strings.

    ``permissions`` is the list form of FullAdminAuth.get_user_permissions(user).
    """
    profile = schema.UserOut.model_validate(auth.user).model_dump()
    return SuccessResponse({**profile, "permissions": list(FullAdminAuth.get_user_permissions(auth.user))})


@router.get("/getMenuList", summary="获取当前用户菜单树")
async def get_menu_list(auth: Auth = Depends(FullAdminAuth())):
    """Return the current user's menu tree as built by MenuDal.get_routers."""
    menu_dal = dal.MenuDal(auth.db)
    menu_tree = await menu_dal.get_routers(auth.user)
    return SuccessResponse(menu_tree)
